#!/usr/bin/env bash
#
# run without epx:
# bash run_deepseekv2_epx.sh
#
# run with epx:
# bash run_deepseekv2_epx.sh -u --ccp_port 9008

# default parameters
USE_EPX=0                   # fault torerance
ROOT_PATH="/root/workspace" # default root path

# print usage
usage() {
    echo "You should prepare "llama2_dataset epx Megatron-LM megatron-lm-musa-patch" in the root path."
    echo "Usage: $0 [OPTIONS]"
    echo "Options:"
    echo "  --root_path        Set epx megatron-lm and megatron-lm-musa-patch root path (default: /root/workspace)."
    echo "  --dataset_path     Set dataset path.(default: /root/workspace/llama2_dataset)"
    echo "  -u, --use_epx      Set use_epx (default: 0)."
    echo "  --master_addr      Set master_addr (default: host address)."
    echo "  --ccp_port         Set ccp_port (default: 9009)."
    echo "  --store_port       Set store_port (default: 45678)."
    echo "  -h, --help         Show this help message and exit."
    echo "  ============= Example: ============="
    echo "  $0 # run without epx"
    echo "  $0 -u --ccp_port 9008 # run with epx and set ccp_port"
    echo "  $0 -u --master_addr xx.xxx.xx.xx # run with epx and set master_addr"
    echo "  $0 -h # Show this help message."
}

# parse command-line arguments
while [[ $# -gt 0 ]]; do
    case "$1" in
        --root_path)
            ROOT_PATH="$2"
            shift
            ;;
        --dataset_path)
            DATA_PATH="$2"
            shift
            ;;
        -u|--use_epx)
            USE_EPX=1
            ;;
        --master_addr)
            MASTER_ADDR="$2"
            shift
            ;;
        --ccp_port)
            EPX_CCP_PORT="$2"
            shift
            ;;
        --store_port)
            EPX_STORE_PORT="$2"
            shift
            ;;
        -h|--help)
            usage
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            exit 1
            ;;
    esac
    shift
done

setup_epx_env() {

    HOST_ADDR=$(ip addr show bond0 | grep -oP 'inet \K[\d.]+' | head -1)
    MASTER_ADDR=${MASER_ADDR:-$HOST_ADDR} # set master addr default to host addr

    echo "EPX HOST_ADDR: $HOST_ADDR"
    echo "EPX MASTER_ADDR: $MASTER_ADDR"

    export USE_EPX=1
    # CCP ADDR and PORT
    export EPX_CCP_ADDR="$MASTER_ADDR"
    export EPX_CCP_PORT=${EPX_CCP_PORT:-9009}
    # # STORE ADDR and PORT
    export EPX_STORE_ADDR="$MASTER_ADDR"
    export EPX_STORE_PORT=${EPX_STORE_PORT:-45678}
    # EPX_LCP_ADDR and PORT
    export EPX_LCP_ADDR="$HOST_ADDR"
    # get epx session
    # EPX_SESSION="$(uuidgen)"
    export EPX_SESSION="a0515990-ffbf-11ef-8a45-6fde377b8f7a"
    export USE_MCCL_BACKEND=1 # epx environment

    export EPX_GROUP_RANK=8

    EPX_PATH=$ROOT_PATH/epx
    EPX_STORE_PATH="${EPX_PATH}/epx-py/examples/epx_store.py"
    export PYTHONPATH=${EPX_PATH}/epx-py/python:$PYTHONPATH
    export EPX_LCP_BIN="$EPX_PATH/target/debug/epx-lcp"

    if [ "$MASTER_ADDR" = "$HOST_ADDR" ]; then
        python $EPX_STORE_PATH --addr "$HOST_ADDR" &
        STORE_PID=$!
    fi
}

set_common_env() {
    # MCCL_DEBUG=TRACE
    # export TORCH_CPP_LOG_LEVEL=INFO
    # export TORCH_DISTRIBUTED_DEBUG="DETAIL"
    # 10:DEBUG, 20:INFO, 30:WARNING, 40:ERROR, 50:FATAL
    # export MEGATRON_LOGGING_LEVEL=30
    export RUN_LOCAL=1
    export GPUS_PER_NODE=8

    export TP_SIZE=1
    export PP_SIZE=1
    export EP_SIZE=8
    export WORLD_SIZE=8
    export MICRO_BATCH_SIZE=1
    export NUM_MICROBATCHES=1
    export MOE_NUM_EXPERTS=160
    export MOE_ROUTER_GROUP_TOPK=1

    export DATA_PATH=${DATA_PATH:-"$ROOT_PATH/llama2_dataset/llama_00_text_document"}
    export TOKENIZED_MODEL="$ROOT_PATH/llama2_dataset/tokenizer.model"

    SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
}

if [[ "$USE_EPX" -ne 0 ]]; then
    setup_epx_env
fi

set_common_env

"$SCRIPT_DIR"/run_deepseekv2.sh &
R0_PID=$!

wait $R0_PID

if [[ "$USE_EPX" -ne 0 && "$MASTER_ADDR" = "$HOST_ADDR" ]]; then
    kill -9 $STORE_PID
    pkill -f $EPX_STORE_PATH
    pkill -f $EPX_STORE_PATH
fi
